Image Embeddings

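  • Generate image embeddings: Use a pre-trained neural network to extract a feature vector (embedding) from each real-world image.
    • Image embedding: Use ResNet50 pre-trained on ImageNet, loaded here through TensorFlow/Keras, with its classification head removed (include_top=False) and global average pooling applied (pooling='avg'), so each image maps to a 2048-dimensional vector.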
  • Store embeddings in FAISS: Save the embeddings in a FAISS index for efficient similarity search.
    • Each image is loaded, resized to 224×224, preprocessed, and passed through ResNet50 by the generate_image_embedding function, which returns the output of the network's convolutional layers after global average pooling.
    • All embeddings are added to a FAISS IndexFlatL2 index. FAISS is a library for efficient similarity search over dense vectors; a flat L2 index performs exact (brute-force) search, ranking the stored vectors by their squared Euclidean (L2) distance to the query.
  • Implement the retrieval step of RAG: Embed a user query and retrieve the top-k most similar images from the index.
    • Only the retrieval half of a RAG pipeline is implemented here; no generation step follows. The query must be an image: ResNet50 embeds images only, so a text query would require a model that maps text and images into a shared embedding space (for example, CLIP).
    • Given a query image, we generate its embedding with the same model and search the FAISS index for the k stored embeddings closest to it by L2 distance. The query_faiss_index function handles this.
  • Outputs:
    • The first table lists every stored image: its index, its file name, and the first 5 values of its embedding.
    • The second table lists the top-k retrieved images (k = 2) for the user query, with their indices, file names, and the first 5 embedding values.
    • The query image and the retrieved images are displayed alongside the tables.
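For intuition about what the FAISS IndexFlatL2 search does, here is a minimal NumPy sketch of exact (brute-force) top-k search. It assumes, as the flat L2 index does, that the score is the squared Euclidean distance, so the distances FAISS reports are squares of the true L2 distances. The function name brute_force_topk and the toy vectors are hypothetical; this is an illustration, not a replacement for FAISS.

```python
import numpy as np


def brute_force_topk(database: np.ndarray, query: np.ndarray, k: int):
    """Exact top-k search by squared L2 distance, mirroring a flat L2 index.

    database: (n, d) array of stored embeddings.
    query:    (d,) array holding one query embedding.
    Returns (distances, indices), each of shape (1, k), like index.search().
    """
    k = min(k, database.shape[0])                    # never ask for more neighbours than stored vectors
    diffs = database - query[np.newaxis, :]          # broadcast the query against every stored row
    sq_dists = np.einsum("ij,ij->i", diffs, diffs)   # row-wise squared Euclidean distance
    order = np.argsort(sq_dists, kind="stable")[:k]  # ascending, so the closest vectors come first
    return sq_dists[order][np.newaxis, :], order[np.newaxis, :]


# Hypothetical toy 3-D "embeddings"
db = np.array([[0.0, 0.0, 0.0],
               [1.0, 0.0, 0.0],
               [0.0, 2.0, 0.0]], dtype="float32")
q = np.array([0.9, 0.0, 0.0], dtype="float32")

distances, indices = brute_force_topk(db, q, k=2)
print(indices)    # [[1 0]]
print(distances)  # squared distances, approximately 0.01 and 0.81
```

Because every stored vector is compared with the query, the cost grows linearly with the number of images; for a handful of images like this notebook's, exact search is the simplest correct choice.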
%pip install -q tensorflow faiss-cpu pandas numpy matplotlib tabulate
Note: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.
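# Image similarity search: embed images with ResNet50, index them in FAISS, and retrieve the top-2 matches for a query image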
import numpy as np
import pandas as pd
import faiss
import os
import matplotlib.pyplot as plt
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input

# Load ResNet50 model pre-trained on ImageNet (without top classification layer)
def load_resnet50_model():
    model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
    return model

# Generate image embedding using ResNet50
def generate_image_embedding(img_path, model):
    img = image.load_img(img_path, target_size=(224, 224))  # Resize image to (224, 224)
    img_array = image.img_to_array(img)  # Convert image to array
    img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension
    img_array = preprocess_input(img_array)  # Preprocess for ResNet50
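    embedding = model.predict(img_array, verbose=0)  # Get embedding; verbose=0 suppresses the per-call progress bar
    return embedding.flatten().astype('float32')  # Flatten to a 1D float32 vector (FAISS requires float32)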

# Create FAISS index to store embeddings
def create_faiss_index(embeddings):
    dim = embeddings.shape[1]  # Dimensionality of the embeddings
    index = faiss.IndexFlatL2(dim)  # Use L2 distance (Euclidean distance)
    index.add(embeddings)  # Add embeddings to the FAISS index
    return index

# Query FAISS index and retrieve top k similar images
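def query_faiss_index(query_embedding, index, k=2):
    k = min(k, index.ntotal)  # FAISS pads results with -1 if k exceeds the number of stored vectors
    distances, indices = index.search(query_embedding.reshape(1, -1).astype('float32'), k)  # Top-k neighbours (squared L2 distances)
    return distances, indices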

# Display image grid for top-k matches
def display_images(images, titles, embeddings, indices):
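    # squeeze=False always yields a 2D array of Axes, so a single result (k=1) also works
    fig, axes = plt.subplots(1, len(images), figsize=(15, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):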
        ax.imshow(images[i])
        ax.set_title(f"Index: {indices[i]}\nEmbedding: {embeddings[i][:5]}")
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# Display table with image metadata (index and embeddings)
def display_image_table(data, title):
    df = pd.DataFrame(data, columns=['Index', 'Image', 'Embedding'])
    print(f"\n{title}")
    print(df.to_markdown(index=False))  # Markdown table formatting (requires the tabulate package)

# Main function to integrate the process
def main():
    # Load ResNet50 model
    model = load_resnet50_model()

    # Path to images directory (use user input or predefined path)
    image_directory = '/curated/ImageStore/images/'

    # Check if the directory exists
    if not os.path.exists(image_directory):
        print(f"Error: The directory '{image_directory}' does not exist.")
        return

    # List image files in the directory (filtered by extension), sorted so indices are reproducible
    image_files = sorted(f for f in os.listdir(image_directory) if f.lower().endswith(('.jpg', '.jpeg', '.png')))

    # Construct full image paths
    image_paths = [os.path.join(image_directory, img) for img in image_files]

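    # Stop early if there is nothing to index (FAISS needs a 2D array with at least one row)
    if not image_paths:
        print(f"Error: No .jpg, .jpeg, or .png files found in '{image_directory}'.")
        return

    # Generate embeddings for the images
    embeddings = []
    for img_path in image_paths:
        embedding = generate_image_embedding(img_path, model)
        embeddings.append(embedding)
    embeddings = np.array(embeddings, dtype='float32')  # FAISS expects a 2D float32 array of shape (n_images, dim)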

    # Create FAISS index
    index = create_faiss_index(embeddings)

    # Display stored images, their index, and embeddings in a table
    stored_image_data = [(i, os.path.basename(image_paths[i]), embeddings[i][:5]) for i in range(len(image_paths))]
    display_image_table(stored_image_data, "Stored Images and Embeddings")

    # User query image: build its path from the image directory and a file name
    user_query_image_name = 'elephant1.png'  # Hardcoded query image for now (can be replaced with user input)
    user_query_image = os.path.join(image_directory, user_query_image_name)

    # Check if the provided query image exists in the directory
    if not os.path.exists(user_query_image):
        print(f"Error: The image '{user_query_image_name}' was not found in the specified directory.")
        return

    # Generate query image embedding
    query_embedding = generate_image_embedding(user_query_image, model)

    # Query FAISS index for top-2 similar images (k=2)
    k = 2  # Set k to 2 to get the top two matches
    distances, indices = query_faiss_index(query_embedding, index, k=k)

    # Retrieve top-2 images and their metadata
    top_k_images = [image.load_img(image_paths[i], target_size=(224, 224)) for i in indices[0]]
    top_k_embeddings = [embeddings[i] for i in indices[0]]
    top_k_indices = indices[0]

    # Display the query image
    query_image = image.load_img(user_query_image, target_size=(224, 224))
    plt.imshow(query_image)
    plt.title(f"User Query: {user_query_image_name}")
    plt.axis('off')
    plt.show()

    # Display the predicted top-k images, their indices, and embeddings
    predicted_image_data = [(top_k_indices[i], os.path.basename(image_paths[top_k_indices[i]]), top_k_embeddings[i][:5]) for i in range(len(top_k_indices))]
    display_image_table(predicted_image_data, f"Predicted Top-{k} Images and Embeddings for: {user_query_image_name}")

    # Display the top-k images
    display_images(top_k_images, [f"Index: {i}" for i in top_k_indices], top_k_embeddings, top_k_indices)

if __name__ == "__main__":
    main()

Stored Images and Embeddings
|   Index | Image              | Embedding                                                |
|--------:|:-------------------|:---------------------------------------------------------|
|       0 | airplane.jpg       | [0.01681192 1.3374338  0.4865353  0.02957397 0.2203022 ] |
|       1 | car.jpg            | [0.02365123 0.21462531 0.39598972 0.04758341 0.13171911] |
|       2 | elephant1.png      | [0.30503857 2.5281742  0.13030702 0.40003565 0.8225097 ] |
|       3 | elephant3.png      | [0.62227046 1.1155356  0.10596917 0.03490202 0.40784183] |
|       4 | elephant_face2.png | [0.03081354 2.703372   0.10826059 0.06344301 1.4554362 ] |
|       5 | fighter_jet1.png   | [2.3893921  0.37274185 0.21137363 0.01661964 0.        ] |
|       6 | fighter_jet2.png   | [1.2492405  0.21738124 1.0315228  0.         0.23579091] |
|       7 | fighter_jet3.png   | [0.94846344 0.21674678 1.0023813  0.         0.11890657] |
|       8 | kangaroo1.png      | [0.89772236 0.94619936 0.5298464  0.25724941 0.05652738] |
|       9 | kangaroo2.png      | [0.869192   0.6844135  0.53440905 0.09354267 1.3998648 ] |
|      10 | kangaroo3.png      | [0.2655031  0.6898773  0.22355784 0.08943271 0.07943194] |
|      11 | koala1.png         | [1.0157447  1.2002645  0.07522814 0.25566643 0.        ] |
|      12 | koala2.png         | [0.44358256 2.0638452  0.25086418 1.0025302  0.00942492] |
|      13 | koala3.png         | [1.1000643  0.8675375  0.29340845 0.02682502 0.2982855 ] |
|      14 | koala4.png         | [0.44358256 2.0638452  0.25086418 1.0025302  0.00942492] |
(Figure: query image, titled "User Query: elephant1.png")
Predicted Top-2 Images and Embeddings for: elephant1.png
|   Index | Image              | Embedding                                                |
|--------:|:-------------------|:---------------------------------------------------------|
|       2 | elephant1.png      | [0.30503857 2.5281742  0.13030702 0.40003565 0.8225097 ] |
|       4 | elephant_face2.png | [0.03081354 2.703372   0.10826059 0.06344301 1.4554362 ] |
(Figure: the two retrieved images, titled "Index: 2" and "Index: 4" with the first 5 values of each embedding)
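In the run above, the query image elephant1.png is itself stored in the index, so it comes back as its own nearest neighbour (index 2), leaving only one genuinely different match. Likewise, koala2.png and koala4.png show identical leading embedding values, which suggests duplicate files. One possible fix, sketched here under assumptions: request one extra neighbour from FAISS and discard the query's own index and any result at (near-)zero distance. The helper filter_self_matches, its eps threshold, and the example distances are hypothetical, not output from this notebook.

```python
import numpy as np


def filter_self_matches(distances, indices, k, exclude_index=None, eps=1e-6):
    """Post-process one row of FAISS-style search results (hypothetical helper).

    distances, indices: arrays of shape (1, m) as returned by index.search(q, m),
        where m should exceed k (e.g. k + 1) so that k results can survive filtering.
    exclude_index: position of the query image in the index, if it is stored there.
    eps: distances at or below this are treated as exact duplicates of the query
        (flat-index distances for identical vectors may not be exactly 0.0).
    Returns up to k (index, distance) pairs, closest first.
    """
    kept = []
    for dist, idx in zip(distances[0], indices[0]):
        if idx < 0:               # FAISS pads with -1 when fewer vectors are stored
            continue
        if idx == exclude_index:  # the query image itself
            continue
        if dist <= eps:           # an identical copy of the query image
            continue
        kept.append((int(idx), float(dist)))
        if len(kept) == k:
            break
    return kept


# Hypothetical search output for k + 1 = 3 neighbours when the query is stored at index 2
d = np.array([[0.0, 310.5, 402.0]], dtype="float32")
i = np.array([[2, 4, 3]], dtype="int64")
result = filter_self_matches(d, i, k=2, exclude_index=2)
print(result)  # [(4, 310.5), (3, 402.0)]
```

In main(), this could be wired in by calling query_faiss_index(query_embedding, index, k=k + 1) and passing exclude_index=image_paths.index(user_query_image), since both paths are built with the same os.path.join call.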